{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# 实战案例：基于多模态因子双线性池化（MFB）的视觉问答\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 读取数据\n",
    "\n",
    "### 下载数据集\n",
    "\n",
    "使用的数据集为VQA v2([下载地址](https://visualqa.org/download.html))\n",
    "- 数据集中的图像来自于MS COCO\n",
    "- - 下载[bottom up attention模型](https://github.com/peteanderson80/bottom-up-attention)提供的[图像区域表示](https://storage.googleapis.com/up-down-attention/trainval_36.zip)。该文件解压后的tsv文件包含了MS COCO训练集和验证集中所有图片的36个检测框的视觉表示\n",
    "- - 这里将tsv文件放在目录../data/vqa/coco下\n",
    "\n",
    "- 对于问题和回答，需要下载四个文件：训练标注集（训练回答集）、验证标注集（训练回答集）、训练问题集和验证问题集\n",
    "- - 将这四个文件解压后，得到4个json格式的文件，并将其放在目录../data/vqa下的vqa2文件夹里\n",
    "- - 由于测试集的标注集未公开，因此这里仅下载训练集和验证集\n",
    "- - 数据集包含443,757个训练问题和214,354个验证问题，每个问题对应10个人工标注的答案"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### 整理数据集\n",
    "\n",
    "数据集下载完成后，我们需要对其进行处理，以适合之后构造的PyTorch数据集类读取。\n",
    "\n",
    "对于图像，我们按照SCAN中介绍的方式，将每张图片的36个检测框表示存储为单个npy格式文件，并将文件路径记录在数据json文件中。为了后续的数据分析，我们还将检测框的位置信息也以npy格式存储。json文件中的路径仅存储文件名前缀，加上后缀'.npy'为图像特征，加上后缀'.box.npy'为检测框特征。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import base64\n",
    "import csv\n",
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "from os.path import join as pjoin\n",
    "import sys\n",
    "\n",
    "csv.field_size_limit(sys.maxsize)\n",
    "\n",
    "def resort_image_feature():\n",
    "    image_feature_path = '../data/vqa/coco/trainval_resnet101_faster_rcnn_genome_36.tsv' \n",
    "    feature_folder = '../data/vqa/coco/image_box_features'\n",
    "    if not os.path.exists(feature_folder):\n",
    "        os.makedirs(feature_folder)\n",
    "        \n",
    "    imgid2feature = {}\n",
    "    imgid2box = {}\n",
    "    FIELDNAMES = ['image_id', 'image_h', 'image_w', 'num_boxes', 'boxes', 'features']\n",
    "    with open(image_feature_path, 'r') as tsv_in_file:\n",
    "        reader = csv.DictReader(tsv_in_file, delimiter='\\t', fieldnames = FIELDNAMES)\n",
    "        for item in reader:\n",
    "            item['num_boxes'] = int(item['num_boxes'])\n",
    "            for field in ['boxes', 'features']:\n",
    "                buf = base64.b64decode(item[field])\n",
    "                temp = np.frombuffer(buf, dtype=np.float32)\n",
    "                item[field] = temp.reshape((item['num_boxes'],-1))\n",
    "            np.save(pjoin(feature_folder, item['image_id']+'.jpg.npy'), item['features'])\n",
    "            np.save(pjoin(feature_folder, item['image_id']+'.jpg.box.npy'), item['boxes'])\n",
    "\n",
    "resort_image_feature()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "对于回答，我们取出现频次最高的max_ans_count个回答，将任务转化为max_ans_count个类的分类任务，并将每个问题的多个回答转化为列表。\n",
    "\n",
    "对于问题，我们首先构建词典，然后根据词典将问题转化为向量，并过滤掉所有回答都不在高频回答中的问题样本，"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict, Counter\n",
    "import json\n",
    "import os\n",
    "from os.path import join as pjoin\n",
    "import random \n",
    "import re\n",
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "def tokenize_mcb(s):\n",
    "    \"\"\"\n",
    "    问题词元化（tokenization）函数，来源于https://github.com/Cadene/block.bootstrap.pytorch/blob/master/block/datasets/vqa_utils.py\n",
    "    \"\"\"\n",
    "    t_str = s.lower()\n",
    "    for i in [r'\\?',r'\\!',r'\\'',r'\\\"',r'\\$',r'\\:',r'\\@',r'\\(',r'\\)',r'\\,',r'\\.',r'\\;']:\n",
    "        t_str = re.sub( i, '', t_str)\n",
    "    for i in [r'\\-',r'\\/']:\n",
    "        t_str = re.sub( i, ' ', t_str)\n",
    "    q_list = re.sub(r'\\?','',t_str.lower()).split(' ')\n",
    "    q_list = list(filter(lambda x: len(x) > 0, q_list))\n",
    "    return q_list\n",
    "\n",
    "def tokenize_questions(questions):\n",
    "    for item in questions:\n",
    "        item['question_tokens'] = tokenize_mcb(item['question'])\n",
    "    return questions\n",
    "\n",
    "def annotations_in_top_answers(annotations, questions, ans_vocab):\n",
    "    new_anno = []\n",
    "    new_ques = []\n",
    "    assert len(annotations) == len(questions)\n",
    "    for anno,ques in zip(annotations, questions):\n",
    "        if anno['multiple_choice_answer'] in ans_vocab:\n",
    "            new_anno.append(anno)\n",
    "            new_ques.append(ques)\n",
    "    return new_anno, new_ques\n",
    "    \n",
    "def encode_questions(questions, vocab):\n",
    "    for item in questions:\n",
    "        item['question_idx'] = [vocab.get(w, vocab['<unk>']) for w in item['question_tokens']]\n",
    "    return questions\n",
    "\n",
    "def encode_answers(annotations, vocab):\n",
    "    # 记录回答的频率\n",
    "    for item in annotations:\n",
    "        item['answer_list'] = []\n",
    "        answers = [a['answer'] for a in item['answers']]\n",
    "        for ans in answers:\n",
    "            if ans in vocab:\n",
    "                item['answer_list'].append(vocab[ans])\n",
    "    return annotations\n",
    "\n",
    "def create_dataset(dataset='flickr8k',\n",
    "                   max_ans_count=1000, \n",
    "                   min_word_count=10):\n",
    "    \"\"\"\n",
    "    参数：\n",
    "        dataset：数据集名称\n",
    "        max_ans_count：取训练集中最高频的1000个答案\n",
    "        min_word_count：仅考虑在训练集中问题文本里出现10次及以上的词\n",
    "    输出：\n",
    "        一个词典文件： vocab.json\n",
    "        两个数据集文件： train_data.json、 val_data.json\n",
    "    \"\"\"\n",
    "    dir_vqa2 = '../data/vqa/vqa2'\n",
    "    dir_processed = os.path.join(dir_vqa2, 'processed')\n",
    "    dir_ann = pjoin(dir_vqa2, 'raw', 'annotations')\n",
    "    path_train_ann = pjoin(dir_ann, 'mscoco_train2014_annotations.json')\n",
    "    path_train_ques = pjoin(dir_ann, 'OpenEnded_mscoco_train2014_questions.json')\n",
    "    path_val_ann = pjoin(dir_ann, 'mscoco_val2014_annotations.json')\n",
    "    path_val_ques = pjoin(dir_ann, 'OpenEnded_mscoco_val2014_questions.json')\n",
    "\n",
    "    # 读取回答和问题\n",
    "    train_anno = json.load(open(path_train_ann))['annotations']\n",
    "    train_ques = json.load(open(path_train_ques))['questions']\n",
    "    \n",
    "    val_anno = json.load(open(path_val_ann))['annotations']\n",
    "    val_ques = json.load(open(path_val_ques))['questions']\n",
    "    \n",
    "    # 取出现频次最高的nans个回答，将任务转化为nans个类的分类问题\n",
    "    ans2ct = defaultdict(int)\n",
    "    for item in train_anno:\n",
    "        ans = item['multiple_choice_answer'] \n",
    "        ans2ct[ans] += 1\n",
    "    ans_ct = sorted(ans2ct.items(), key=lambda item:item[1], reverse=True)\n",
    "    ans_vocab = [ans_ct[i][0] for i in range(max_ans_count)] \n",
    "    ans_vocab = {a:i for i,a in enumerate(ans_vocab)}\n",
    "    train_anno = encode_answers(train_anno, ans_vocab)\n",
    "    val_anno = encode_answers(val_anno, ans_vocab)\n",
    "    # 处理问题文本\n",
    "    train_ques = tokenize_questions(train_ques)\n",
    "    val_ques = tokenize_questions(val_ques)\n",
    "    # 保留高频词\n",
    "    ques_vocab = Counter()\n",
    "    for item in train_ques:\n",
    "        ques_vocab.update(item['question_tokens'])\n",
    "    ques_vocab = [w for w in ques_vocab.keys() if ques_vocab[w] > min_word_count]\n",
    "    ques_vocab = {q:i for i,q in enumerate(ques_vocab)}\n",
    "    ques_vocab['<unk>'] = len(ques_vocab)\n",
    "    train_ques = encode_questions(train_ques, ques_vocab)\n",
    "    val_ques = encode_questions(val_ques, ques_vocab)\n",
    "    # 过滤掉所有回答都不在高频回答中的数据\n",
    "    train_anno, train_ques = annotations_in_top_answers(\n",
    "            train_anno, train_ques, ans_vocab)\n",
    "\n",
    "    if not os.path.exists(dir_processed):\n",
    "        os.makedirs(dir_processed)\n",
    "    # 存储问题和回答词典\n",
    "    with open(pjoin(dir_processed, 'vocab.json'), 'w') as fw:\n",
    "        json.dump({'ans_vocab': ans_vocab, 'ques_vocab': ques_vocab}, fw)\n",
    "    # 存储数据\n",
    "    with open(pjoin(dir_processed, 'train_data.json'), 'w') as fw:\n",
    "        json.dump({'annotations':train_anno, 'questions':train_ques}, fw)\n",
    "    with open(pjoin(dir_processed, 'val_data.json'), 'w') as fw:\n",
    "        json.dump({'annotations':val_anno, 'questions':val_ques}, fw)\n",
    "    \n",
    "create_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "在调用该函数生成需要的格式的数据集文件之后，我们可以展示其中一条数据，简单验证下数据的格式是否和我们预想的一致。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import json\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "from os.path import join as pjoin\n",
    "from PIL import Image\n",
    "\n",
    "data_dir = '../data/vqa/vqa2/'\n",
    "dir_processed = pjoin(data_dir, 'processed')\n",
    "rcnn_dir='../data/vqa/coco/image_box_features/'\n",
    "image_dir = '../data/vqa/coco/raw/val2014/'\n",
    "\n",
    "vocab = json.load(open(pjoin(dir_processed, 'vocab.json'), 'r'))\n",
    "dataset = json.load(open(pjoin(dir_processed, 'val_data.json'), 'r'))\n",
    "\n",
    "idx2ans = {i:a for a,i in vocab['ans_vocab'].items()}\n",
    "idx2ques = {i:q for q,i in vocab['ques_vocab'].items()}\n",
    "\n",
    "# 打印验证集中第10000个样本\n",
    "idx = 10000\n",
    "# 读取问题\n",
    "question = dataset['questions'][idx]\n",
    "q_text = ' '.join([idx2ques[token] for token in question['question_idx']])\n",
    "# 读取回答\n",
    "annotation = dataset['annotations'][idx]\n",
    "a_text = '/'.join([idx2ans[token] for token in annotation['answer_list']])\n",
    "\n",
    "image_name = 'COCO_val2014_%012d.jpg'%(question['image_id'])\n",
    "content_img = Image.open(pjoin(image_dir, image_name))\n",
    "fig = plt.imshow(content_img)\n",
    "feats = np.load(pjoin(rcnn_dir, '{}.jpg.box.npy').format(question['image_id']))\n",
    "for i in range(feats.shape[0]):\n",
    "    bbox = feats[i,:]\n",
    "    color = 'red'\n",
    "    if i > 3:\n",
    "        color = 'blue'\n",
    "    fig.axes.add_patch(plt.Rectangle(\n",
    "        xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],\n",
    "        fill=False, edgecolor=color, linewidth=1))\n",
    "\n",
    "print(image_name)\n",
    "print('question: ', q_text)\n",
    "print('answer: ', a_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### 定义数据集类\n",
    "\n",
    "在准备好的数据集的基础上，我们需要进一步定义PyTorch Dataset类，以使用PyTorch DataLoader类按批次产生数据。PyTorch中仅预先定义了图像、文本和语音的单模态任务中常见的数据集类。因此，我们需要定义自己的数据集类。\n",
    "\n",
    "在PyTorch中定义数据集类非常简单，仅需要继承torch.utils.data.Dataset类，并实现\\_\\_getitem\\_\\_和\\_\\_len\\_\\_两个函数即可。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from argparse import Namespace\n",
    "import collections\n",
    "import numpy as np\n",
    "import skipthoughts\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "def collate_fn(batch):\n",
    "    # 对一个批次的数据进行预处理\n",
    "    max_question_length = max([len(item['question']) for item in batch])\n",
    "    batch_size = len(batch)\n",
    "    imgs = torch.zeros(batch_size, batch[0]['image_feat'].shape[0], batch[0]['image_feat'].shape[1])\n",
    "    ques = torch.zeros(batch_size, max_question_length, dtype=torch.long)\n",
    "    ans = torch.zeros(batch_size, 1000)\n",
    "    lens = torch.zeros(batch_size, dtype=torch.long)\n",
    "    for i,item in enumerate(batch):\n",
    "        imgs[i] = torch.from_numpy(item['image_feat'])\n",
    "        ques[i, :item['question'].shape[0]] = item['question']\n",
    "        for answer in item['answers']:\n",
    "            ans[i, answer] += 1\n",
    "        lens[i] = item['length']\n",
    "    return (imgs, ques, ans, lens)\n",
    "\n",
    "class VQA2Dataset(Dataset):\n",
    "\n",
    "    def __init__(self,\n",
    "            data_dir='../data/vqa/vqa2/',\n",
    "            rcnn_dir='../data/vqa/coco/image_box_features/',\n",
    "            split='train',\n",
    "            samplingans=True):\n",
    "        \"\"\"\n",
    "        参数：\n",
    "            samplingans: 决定返回的回答数据。取值为True，则从回答列表中按照回答出现的概率采样一个回答；取值为False，则为回答列表。\n",
    "        \"\"\"    \n",
    "        super(VQA2Dataset, self).__init__()\n",
    "        self.rcnn_dir = rcnn_dir\n",
    "        self.samplingans = samplingans\n",
    "        self.split = split\n",
    "        dir_processed = os.path.join(data_dir, 'processed')\n",
    "        if split == 'train':\n",
    "            self.dataset = json.load(open(pjoin(dir_processed, 'train_data.json'), 'r'))\n",
    "        elif split == 'val':\n",
    "            self.dataset = json.load(open(pjoin(dir_processed, 'val_data.json'), 'r'))\n",
    "        self.dataset_size = len(self.dataset['questions'])\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        item = {}\n",
    "        item['index'] = index\n",
    "\n",
    "        # 读取问题\n",
    "        question = self.dataset['questions'][index]\n",
    "        item['question'] = torch.LongTensor(question['question_idx'])\n",
    "        item['length'] = torch.LongTensor([len(question['question_idx'])])\n",
    "        # 读取图像检测框特征\n",
    "        item['image_feat'] = np.load(os.path.join(self.rcnn_dir, '{}.jpg.npy'.format(question['image_id'])))\n",
    "        # 读取回答\n",
    "        annotation = self.dataset['annotations'][index]\n",
    "        if 'train' in self.split and self.samplingans:\n",
    "            item['answers'] = [random.choice(annotation['answer_list'])]\n",
    "        else:\n",
    "            item['answers'] = annotation['answer_list']\n",
    "        return item\n",
    "        \n",
    "    def __len__(self):\n",
    "        return self.dataset_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### 批量读取数据\n",
    "\n",
    "利用刚才构造的数据集类，借助DataLoader类构建能够按批次产生训练、验证和测试数据的对象。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mktrainval(data_dir, image_feat_dir, batch_size, workers=0):\n",
    "    train_set = VQA2Dataset(data_dir, image_feat_dir, split='train', samplingans=True)\n",
    "    valid_set = VQA2Dataset(data_dir, image_feat_dir, split='val', samplingans=False)\n",
    "   \n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "                        train_set, batch_size=batch_size, \n",
    "                        shuffle=True, num_workers=workers, \n",
    "                        pin_memory=True, collate_fn=collate_fn)\n",
    "    valid_loader = torch.utils.data.DataLoader(\n",
    "                        valid_set, batch_size=batch_size, \n",
    "                        shuffle=False, num_workers=workers, \n",
    "                        pin_memory=True, drop_last=False, collate_fn=collate_fn)\n",
    "\n",
    "    return train_loader, valid_loader    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 定义模型\n",
    "\n",
    "\n",
    "<img src='img/mf-mfbvqa-framework.png' style='float: middle;' width=800 height=800>\n",
    "\n",
    "MFBVQA模型主要包含两个模块：\n",
    "\n",
    "- 注意力跨模态对齐模块\n",
    "- - 使用问题表示作为查询，图像的局部表示作为键和值，获得问题和图像对齐的表示。\n",
    "- - 形式上，该表示为图像局部表示的加权求和的结果，权重则代表了图像区域和该问题的关联程度。\n",
    "- - 这里使用的是多头注意力，且注意力评分函数中计算查询和键的关联时，使用了MFB融合操作。\n",
    "\n",
    "- 双线性融合模块\n",
    "- - 使用MFB融合问题对齐前后的表示，获得最终的融合表示。\n",
    "\n",
    "下面我们将首先实现MFB融合操作，然后实现基于MFB融合的注意力跨模态对齐，最后借助注意力跨模态对齐和MFB融合实现MFBVQA模型。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### MFB融合\n",
    "\n",
    "下面展示MFB融合操作的实现。该函数既支持两个向量融合，也支持两组向量的融合。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MFBFusion(nn.Module):\n",
    "    def __init__(self, input_dim1, input_dim2, hidden_dim, R):\n",
    "        '''\n",
    "        参数：\n",
    "            input_dim1: 第一个待融合表示的维度\n",
    "            input_dim2: 第二个待融合表示的维度\n",
    "            hidden_dim: 融合后的表示的维度\n",
    "            R: MFB所使用的低秩矩阵的维度\n",
    "        '''\n",
    "        super(MFBFusion, self).__init__()\n",
    "        self.input_dim1 = input_dim1\n",
    "        self.input_dim2 = input_dim2\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.R = R\n",
    "        self.linear1 = nn.Linear(input_dim1, hidden_dim * R)\n",
    "        self.linear2 = nn.Linear(input_dim2, hidden_dim * R)\n",
    "\n",
    "    def forward(self, inputs1, inputs2):\n",
    "        '''\n",
    "        参数：\n",
    "            inputs1: (batch_size, input_dim1) 或 (batch_size, num_region, input_dim1)\n",
    "            inputs2: (batch_size, input_dim2) 或 (batch_size, num_region, input_dim2)\n",
    "        '''\n",
    "        # -> total: (batch_size, hidden_dim) 或 (batch_size, num_region, hidden_dim)\n",
    "        num_region = 1\n",
    "        if inputs1.dim() == 3:\n",
    "            num_region = inputs1.size(1)\n",
    "        h1 = self.linear1(inputs1)\n",
    "        h2 = self.linear2(inputs2)\n",
    "        z = h1 * h2\n",
    "        z = z.view(z.size(0), num_region, self.hidden_dim, self.R)\n",
    "        z = z.sum(3).squeeze(1)\n",
    "        return z"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### 注意力跨模态对齐模块\n",
    "\n",
    "下面展示了多头交叉注意力的实现。其中注意力得分 $\\alpha$ 是使用MFB操作融合查询和键的结果。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadATTN(nn.Module):\n",
    "    def __init__(self, query_dim, kv_dim, mfb_input_dim, mfb_hidden_dim, num_head, att_dim):\n",
    "        \"\"\"\n",
    "        参数：\n",
    "            query_dim：问题表示（查询）的维度\n",
    "            kv_dim：图像区域表示（键和值）的维度\n",
    "            mfb_input_dim：融合操作的输入的维度\n",
    "            mfb_hidden_dim：融合操作的输出的维度\n",
    "            num_head：多头交叉注意力的头数\n",
    "            att_dim：多头交叉注意力的输出表示（对齐后的表示）维度\n",
    "        \"\"\"\n",
    "        super(MultiHeadATTN, self).__init__()\n",
    "        assert att_dim % num_head == 0\n",
    "        self.num_head = num_head\n",
    "        self.att_dim = att_dim\n",
    "\n",
    "        self.attn_w_1_q = nn.Sequential(\n",
    "                            nn.Dropout(0.5),\n",
    "                            nn.Linear(query_dim, mfb_input_dim),\n",
    "                            nn.ReLU()\n",
    "                          )\n",
    "        self.attn_w_1_k = nn.Sequential(\n",
    "                            nn.Dropout(0.5),\n",
    "                            nn.Linear(kv_dim, mfb_input_dim),\n",
    "                            nn.ReLU()\n",
    "                          )\n",
    "        self.attn_score_fusion = MFBFusion(mfb_input_dim, mfb_input_dim, mfb_hidden_dim, 1)\n",
    "        self.attn_score_mapping = nn.Sequential(\n",
    "                            nn.Dropout(0.5),\n",
    "                            nn.Linear(mfb_hidden_dim, num_head)\n",
    "                          )\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "        # 对齐后的表示计算流程\n",
    "        self.align_q = nn.ModuleList([nn.Sequential(\n",
    "                            nn.Dropout(0.5),\n",
    "                            nn.Linear(kv_dim, int(att_dim / num_head)),\n",
    "                            nn.Tanh()\n",
    "                       ) for _ in range(num_head)])\n",
    "        \n",
    "    def forward(self, query, key_value):\n",
    "        \"\"\"\n",
    "        参数：\n",
    "          query: (batch_size, q_dim)\n",
    "          key_value: (batch_size, num_region, kv_dim)\n",
    "        \"\"\"\n",
    "        #（1）使用全连接层将Q、K、V转化为向量\n",
    "        num_region = key_value.shape[1]\n",
    "        # -> (batch_size, num_region, mfb_input_dim)\n",
    "        q = self.attn_w_1_q(query).unsqueeze(1).repeat(1,num_region,1)\n",
    "        # -> (batch_size, num_region, mfb_input_dim)\n",
    "        k = self.attn_w_1_k(key_value)\n",
    "        #（2）计算query和key的相关性，实现注意力评分函数\n",
    "        # -> (batch_size, num_region, num_head)\n",
    "        alphas = self.attn_score_fusion(q, k)\n",
    "        alphas = self.attn_score_mapping(alphas)\n",
    "        #（3）归一化相关性分数\n",
    "        # -> (batch_size, num_region, num_head)\n",
    "        alphas = self.softmax(alphas)\n",
    "        #（4）计算输出\n",
    "        # (batch_size, num_region, num_head) (batch_size, num_region, key_value_dim)\n",
    "        # -> (batch_size, num_head, key_value_dim)\n",
    "        output = torch.bmm(alphas.transpose(1,2), key_value)\n",
    "        # 最终再对每个头的输出进行一次转换，并拼接所有头的转换结果作为注意力输出\n",
    "        list_v = [e.squeeze() for e in torch.split(output, 1, dim=1)]\n",
    "        alpha = torch.split(alphas, 1, dim=2)\n",
    "        align_feat = torch.cat([self.align_q[head_id](x_v) for head_id, x_v in enumerate(list_v)], 1)\n",
    "        return align_feat, alpha\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### 模型\n",
    "\n",
    "利用上述MFB融合操作和多头自注意力模块的实现，我们可以轻松的实现MFBVQA模型。模型的输入是图像的区域表示和问题。对于问题的表示，模型使用预训练的Skip-thoughts向量  :cite:`Ki.Zh.Sa.ea.2015` 作为问题的整体表示。这里，Skip-thoughts向量提取模型是使用句子作为输入，句子的上下文句子作为监督信息训练而得。\n",
    "<!-- 如图 2-5 所示，输入\"I could see the cat on the steps\"，模型首先通过 GRU 编码句子，然后再使用 GRU 解码预测输出这句话可能的上一句\"I got back home\"和可能的下一句\"This was strange\"。 -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MFBVQAModel(nn.Module):\n",
    "    def __init__(self, vocab_words, qestion_dim, image_dim, \n",
    "                       attn_mfb_input_dim, attn_mfb_hidden_dim, \n",
    "                       attn_num_head, attn_output_dim, \n",
    "                       fusion_q_feature_dim, fusion_mfb_hidden_dim,\n",
    "                       num_classes):\n",
    "        super(MFBVQAModel, self).__init__()\n",
    "\n",
    "        # 文本表示提取器\n",
    "        self.text_encoder = skipthoughts.BayesianUniSkip('../data/vqa/skipthoughts/',\n",
    "                                    vocab_words,\n",
    "                                    dropout=0.25,\n",
    "                                    fixed_emb=False)\n",
    "        # 多头自注意力\n",
    "        self.attn = MultiHeadATTN(qestion_dim, image_dim, \n",
    "                                  attn_mfb_input_dim, attn_mfb_hidden_dim, \n",
    "                                  attn_num_head, attn_output_dim)\n",
    "        # 问题的对齐表示到融合表示空间的映射函数\n",
    "        self.q_feature_linear = nn.Sequential(\n",
    "                                    nn.Dropout(0.5),\n",
    "                                    nn.Linear(qestion_dim, fusion_q_feature_dim),\n",
    "                                    nn.ReLU()\n",
    "                                )\n",
    "        # MFB融合图文表示类\n",
    "        self.fusion = MFBFusion(attn_output_dim, fusion_q_feature_dim, fusion_mfb_hidden_dim, 2)\n",
    "        # 分类器\n",
    "        self.classifier_linear = nn.Sequential(\n",
    "                                    nn.Dropout(0.5),\n",
    "                                    nn.Linear(fusion_mfb_hidden_dim, num_classes)\n",
    "                                )\n",
    "    \n",
    "    def forward(self, imgs, quests, lengths):\n",
    "        # 初始输入\n",
    "        v_feature = imgs.contiguous().view(-1, 36, 2048)\n",
    "        q_emb = self.text_encoder.embedding(quests)\n",
    "        q_feature, _ = self.text_encoder.rnn(q_emb)\n",
    "        q_feature = self.text_encoder._select_last(q_feature, lengths)\n",
    "        # 利用注意力获得问题的对齐表示\n",
    "        align_q_feature, _ = self.attn(q_feature, v_feature)  # b*620\n",
    "        # 对原始文本表示进行变换\n",
    "        original_q_feature =  self.q_feature_linear(q_feature)\n",
    "        # 融合对齐前后的问题的表示\n",
    "        x = self.fusion(align_q_feature, original_q_feature)\n",
    "        # 分类\n",
    "        x = self.classifier_linear(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 定义损失函数\n",
    "\n",
    "模型的损失函数为KL散度损失，同时兼容回答为单一值和列表两种情形。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KLLoss(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(KLLoss, self).__init__()\n",
    "        self.loss = nn.KLDivLoss(reduction='batchmean')\n",
    "\n",
    "    def forward(self, input, target):\n",
    "        return self.loss(nn.functional.log_softmax(input), target)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 选择优化方法\n",
    "\n",
    "我们选用Adam优化算法来更新模型参数，学习速率采用指数衰减方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "\n",
    "def get_optimizer(model, config):\n",
    "    return torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)\n",
    "    \n",
    "def get_lr_scheduler(optimizer):\n",
    "    \"\"\"学习速率指数衰减\"\"\"\n",
    "    return lr_scheduler.ExponentialLR(optimizer, 0.5 ** (1 / 50000))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 评估指标\n",
    "\n",
    "这里实现了VQAv2数据集中最常用的评估指标——回答准确率。具体而言，如果模型给出的回答在人工标注的10个回答中出现了3次及以上，则该回答的准确率为1，出现两次和一次的准确率分别为2/3和1/3。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(data_loader, model):\n",
    "    model.eval()\n",
    "    device = next(model.parameters()).device\n",
    "    accs = []\n",
    "    for i, (imgs, questions, answers, lengths) in enumerate(data_loader):\n",
    "        imgs = imgs.to(device)\n",
    "        questions = questions.to(device)\n",
    "        answers = answers.to(device)\n",
    "        lengths = lengths.to(device)\n",
    "        \n",
    "        output = model(imgs, questions, lengths)\n",
    "        hit_cts = answers[torch.arange(output.size(0)),output.argmax(dim=1)]\n",
    "        for hit_ct in hit_cts:\n",
    "            accs.append(min(1, hit_ct / 3.0))\n",
    "    model.train()\n",
    "    return float(sum(accs))/len(accs)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 训练模型\n",
    "\n",
    "训练模型过程可以分为读取数据、前馈计算、计算损失、更新参数、选择模型五个步骤。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "# 设置模型超参数和辅助变量\n",
    "config = Namespace(\n",
    "    question_dim = 2400, \n",
    "    image_dim = 2048, \n",
    "    attn_mfb_input_dim = 310, \n",
    "    attn_mfb_hidden_dim = 510,\n",
    "    attn_num_head = 2, \n",
    "    attn_output_dim = 620,\n",
    "    fusion_q_feature_dim = 310,\n",
    "    fusion_mfb_hidden_dim = 510,\n",
    "    num_ans = 1000,\n",
    "    batch_size = 128,\n",
    "    learning_rate = 0.0001,\n",
    "    margin = 0.2,\n",
    "    num_epochs = 45,\n",
    "    grad_clip = 0.25,\n",
    "    evaluate_step = 360, # 每隔多少步在验证集上测试一次\n",
    "    checkpoint = None, # 如果不为None，则利用该变量路径的模型继续训练\n",
    "    best_checkpoint = '../model/mfb/best_vqa2.ckpt', # 验证集上表现最优的模型的路径\n",
    "    last_checkpoint = '../model/mfb/last_vqa2.ckpt' # 训练完成时的模型的路径\n",
    ")\n",
    "\n",
    "# 设置GPU信息\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '3'\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") \n",
    "\n",
    "# 数据\n",
    "data_dir = '../data/vqa/vqa2/'\n",
    "dir_processed = os.path.join(data_dir, 'processed')\n",
    "\n",
    "train_loader, valid_loader = mktrainval(data_dir, \n",
    "               '../data/vqa/coco/image_box_features/', \n",
    "               config.batch_size, \n",
    "               workers=0)\n",
    "\n",
    "# 模型\n",
    "vocab = json.load(open(pjoin(dir_processed, 'vocab.json'), 'r'))  \n",
    "# 随机初始化 或 载入已训练的模型\n",
    "start_epoch = 0\n",
    "checkpoint = config.checkpoint\n",
    "if checkpoint is None:\n",
    "    model = MFBVQAModel(vocab['ques_vocab'], \n",
    "                        config.question_dim, \n",
    "                        config.image_dim, \n",
    "                        config.attn_mfb_input_dim, \n",
    "                        config.attn_mfb_hidden_dim, \n",
    "                        config.attn_num_head, \n",
    "                        config.attn_output_dim, \n",
    "                        config.fusion_q_feature_dim, \n",
    "                        config.fusion_mfb_hidden_dim,\n",
    "                        config.num_ans)\n",
    "else:\n",
    "    checkpoint = torch.load(checkpoint)\n",
    "    start_epoch = checkpoint['epoch'] + 1\n",
    "    model = checkpoint['model']\n",
    "\n",
    "# 优化器\n",
    "optimizer = get_optimizer(model, config)\n",
    "lrscheduler = get_lr_scheduler(optimizer)\n",
    "\n",
    "# 将模型拷贝至GPU，并开启训练模式\n",
    "model.to(device)\n",
    "model.train()\n",
    "\n",
    "# 损失函数\n",
    "loss_fn = KLLoss().to(device)\n",
    "\n",
    "best_res = 0\n",
    "print(\"开始训练\")\n",
    "fw = open('log.txt', 'w')\n",
    "for epoch in range(start_epoch, config.num_epochs):\n",
    "    for i, (imgs, questions, answers, lengths) in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "        # 1. 读取数据至GPU\n",
    "        imgs = imgs.to(device)\n",
    "        questions = questions.to(device)\n",
    "        answers = answers.to(device)\n",
    "        lengths = lengths.to(device)\n",
    "\n",
    "        # 2. 前馈计算\n",
    "        output = model(imgs, questions, lengths)\n",
    "        # 3. 计算损失\n",
    "        loss = loss_fn(output, answers)\n",
    "        loss.backward()\n",
    "        \n",
    "        # 梯度截断\n",
    "        if config.grad_clip > 0:\n",
    "            nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)\n",
    "\n",
    "        # 4. 更新参数\n",
    "        optimizer.step()\n",
    "        lrscheduler.step()\n",
    "\n",
    "        state = {\n",
    "                'epoch': epoch,\n",
    "                'step': i,\n",
    "                'model': model,\n",
    "                'optimizer': optimizer\n",
    "                }\n",
    "        \n",
    "        if (i+1) % config.evaluate_step == 0:\n",
    "            acc = evaluate(valid_loader, model)\n",
    "            # 5. 选择模型\n",
    "            if best_res < acc:\n",
    "                best_res = acc\n",
    "                torch.save(state, config.best_checkpoint)\n",
    "            torch.save(state, config.last_checkpoint)\n",
    "            print('epoch: %d, step: %d, loss: %.2f, \\\n",
    "                  ACC: %.3f' % \n",
    "                  (epoch, i+1, loss.item(), acc))\n",
    "            fw.write('epoch: %d, step: %d, loss: %.2f, \\\n",
    "                  ACC: %.3f' % \n",
    "                  (epoch, i+1, loss.item(), acc)+'\\n')\n",
    "fw.close()"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "rise": {
   "enable_chalkboard": true,
   "scroll": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
